Original DNS Exfiltration Detection Model
Uses deep learning to detect DNS exfiltration attempts.
"""

import json
import math
import os
import string
from collections import Counter
from typing import Dict, List, Any, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from rich.progress import Progress, SpinnerColumn, TextColumn
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

# Import shared theme configuration
from theme_config import console

# Module-wide settings
MODEL_DIRECTORY = "/content/models"  # directory init() checks for a saved checkpoint
BATCH_SIZE = 256  # mini-batch size handed to the training DataLoader
THRESHOLD = 0.5  # scores at or above this value are flagged as exfiltration
# Printable ASCII without the whitespace characters. A character's position in
# this list is the column index of its frequency feature.
PRINTABLE_CHARS = [char for char in string.printable if char not in string.whitespace]

class DNSExfiltrationDetector(nn.Module):
    """Feed-forward binary classifier over DNS query feature vectors.

    Two 128-unit hidden layers, each followed by ReLU and dropout, feed a
    single sigmoid output unit.
    """

    def __init__(self, input_size: int = 98):
        """Build the layers and report that the architecture is ready.

        Args:
            input_size: Number of features in each input vector.
        """
        super().__init__()
        hidden_units = 128

        # Linear layers (created in this order so parameter initialisation
        # consumes the RNG in a fixed sequence).
        self.layer_1 = nn.Linear(input_size, hidden_units)
        self.layer_2 = nn.Linear(hidden_units, hidden_units)
        self.layer_out = nn.Linear(hidden_units, 1)

        # Parameter-free building blocks shared by both hidden layers.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(p=0.5)

        console.print("[info]Model architecture initialized[/]")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid score the network produces for ``x``."""
        hidden = self.dropout(self.relu(self.layer_1(x)))
        hidden = self.dropout(self.relu(self.layer_2(hidden)))
        score = self.sigmoid(self.layer_out(hidden))
        return score

class DNSFeatureProcessor:
    """Turns raw DNS query rows into the numeric features fed to the model.

    The instance keeps scratch lists (``text_rows``, ``size_avg``,
    ``entropy_avg``) that are filled while a frame is processed.
    ``prepare_features`` resets them on entry, so one processor can be reused
    for any number of frames.
    """

    def __init__(self):
        self.text_rows = []     # one {char index: count} dict per processed query
        self.size_avg = []      # mean subdomain length, one entry per rank-1 row
        self.entropy_avg = []   # mean subdomain entropy, one entry per rank-1 row
        self.scaler = StandardScaler()

    def clear_state(self):
        """Clear accumulated state."""
        self.text_rows.clear()
        self.size_avg.clear()
        self.entropy_avg.clear()

    def compute_entropy(self, domain: str) -> float:
        """Calculate the Shannon entropy (in bits) of a string.

        Args:
            domain: Text whose character distribution is measured.

        Returns:
            The entropy in bits; ``0.0`` for an empty string.
        """
        if not domain:
            return 0.0
        total = len(domain)
        return -sum(
            (count / total) * math.log2(count / total)
            for count in Counter(domain).values()
        )

    def index_chars(self, query: str) -> None:
        """Append the character-frequency dict of ``query`` to ``text_rows``.

        Keys are positions in ``PRINTABLE_CHARS``; characters outside that
        list are ignored.
        """
        # Count once, then look each known character up in the counter. This
        # avoids two linear list scans per character of the query.
        counts = Counter(query)
        char_freq = {
            idx: counts[char]
            for idx, char in enumerate(PRINTABLE_CHARS)
            if char in counts
        }
        self.text_rows.append(char_freq)

    def compute_aggregated_features(self, row: pd.Series, df: pd.DataFrame) -> None:
        """Append the mean ``len`` and ``entropy`` of rows matching ``row``.

        Rows of ``df`` match when they share ``row``'s ``src`` and ``tld``.
        The two means are appended to ``size_avg`` and ``entropy_avg``.
        """
        prev_events = df[(df['src'] == row['src']) & (df['tld'] == row['tld'])]
        self.size_avg.append(prev_events['len'].mean())
        self.entropy_avg.append(prev_events['entropy'].mean())

    def prepare_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Process DNS queries and extract features.

        Args:
            df: Frame with at least the columns ``query``, ``src`` and
                ``rank``. It is not modified.

        Returns:
            A copy of the rows whose ``rank`` equals 1, extended with one
            frequency column per entry of ``PRINTABLE_CHARS`` and the columns
            ``subdomain``, ``tld``, ``len``, ``entropy``, ``size_avg`` and
            ``entropy_avg``. Its index is positional (0-based row number in
            ``df``).
        """
        console.print("[info]Extracting features from DNS queries...[/]")

        # Drop whatever a previous call left behind; otherwise the frequency
        # matrix and the aggregate lists grow past the size of this frame.
        self.clear_state()

        # Extract character frequencies
        for query in df['query']:
            self.index_chars(query)
        char_freq_df = pd.DataFrame(self.text_rows, columns=range(len(PRINTABLE_CHARS)))
        char_freq_df = char_freq_df.fillna(0)

        # Add character frequency features (both sides share a 0..n-1 index)
        df = pd.concat([char_freq_df, df.reset_index(drop=True)], axis=1)

        # Extract domain parts. With three or more labels the last two form
        # 'tld'; with two labels only the last one does; with a single label
        # 'tld' is the empty string.
        df['subdomain'] = df['query'].apply(lambda x: str(x).rsplit('.', 2)[0])
        df['tld'] = df['query'].apply(lambda x: '.'.join(str(x).rsplit('.', 2)[1:]))

        # Calculate basic features
        df['len'] = df['subdomain'].apply(len)
        df['entropy'] = df['subdomain'].apply(self.compute_entropy)

        # Calculate aggregated features for most recent queries. A plain loop
        # is used because the helper works through side effects.
        recent_df = df[df['rank'] == 1].copy()
        for _, row in recent_df[['src', 'tld']].iterrows():
            self.compute_aggregated_features(row, df)
        recent_df['size_avg'] = self.size_avg
        recent_df['entropy_avg'] = self.entropy_avg

        console.print("[success]Feature extraction complete[/]")
        return recent_df

class DNSExfiltrationModel:
    """Main class for DNS exfiltration detection.

    Bundles the neural network, the feature processor and the torch device,
    and exposes training, prediction and checkpoint save/load.
    """

    # Columns carried through feature preparation that are not model inputs.
    _NON_FEATURE_COLUMNS = ['src', 'query', 'rank', 'subdomain', 'tld']
    # Training label; must never reach the network as a feature.
    _LABEL_COLUMN = 'is_exfiltration'

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # The network must live on the same device as the tensors fed to it.
        self.model = DNSExfiltrationDetector().to(self.device)
        self.feature_processor = DNSFeatureProcessor()
        console.print(f"[info]Using device: {self.device}[/]")

    def _feature_tensor(self, processed_df: pd.DataFrame) -> torch.Tensor:
        """Build the float32 input tensor from a prepared feature frame."""
        features = processed_df.drop(columns=self._NON_FEATURE_COLUMNS)
        # The label is present during training and may be present at
        # prediction time; either way it is not an input feature.
        features = features.drop(columns=[self._LABEL_COLUMN], errors='ignore')
        return torch.tensor(
            features.to_numpy(dtype=np.float32),
            dtype=torch.float32,
            device=self.device,
        )

    def train(self, df: pd.DataFrame, epochs: int = 100,
              learning_rate: float = 0.001) -> Dict[str, List[float]]:
        """Train the model on DNS query data.

        Args:
            df: Frame with the columns required by the feature processor plus
                the label column ``is_exfiltration``.
            epochs: Number of passes over the training rows.
            learning_rate: Step size for the Adam optimizer.

        Returns:
            ``{'loss': [...]}`` with the mean batch loss of each epoch.

        Raises:
            ValueError: If feature preparation yields no rows to train on.
        """
        console.print("[info]Starting model training...[/]")

        # Process features
        processed_df = self.feature_processor.prepare_features(df)

        # Prepare training data. Labels come from the processed frame so they
        # line up row-for-row with the features (only rank == 1 rows remain).
        X = self._feature_tensor(processed_df)
        y = torch.tensor(
            processed_df[self._LABEL_COLUMN].to_numpy(dtype=np.float32),
            dtype=torch.float32,
            device=self.device,
        )
        if len(X) == 0:
            raise ValueError("No rows with rank == 1 available for training")

        # Training setup
        self.model.train()  # re-enable dropout in case predict() ran before
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        criterion = nn.BCELoss()
        dataloader = DataLoader(TensorDataset(X, y), batch_size=BATCH_SIZE, shuffle=True)

        # Training progress
        history = {'loss': []}

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("[cyan]Training...", total=epochs)

            for epoch in range(epochs):
                epoch_loss = 0.0
                for batch_X, batch_y in dataloader:
                    optimizer.zero_grad()
                    # squeeze(-1) keeps the batch axis even for a batch of one
                    outputs = self.model(batch_X).squeeze(-1)
                    loss = criterion(outputs, batch_y)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()

                history['loss'].append(epoch_loss / len(dataloader))
                progress.update(task, advance=1)

        console.print("[success]Training complete![/]")
        return history

    def predict(self, df: pd.DataFrame) -> pd.DataFrame:
        """Make predictions on new data.

        Only rows whose ``rank`` equals 1 are scored (that is what the feature
        processor returns). The two prediction columns are added to ``df`` in
        place; rows that were not scored get ``NaN`` as probability and
        ``False`` as flag.

        Args:
            df: Frame with the columns required by the feature processor.

        Returns:
            The same ``df`` object with ``pred_is_dns_data_exfiltration_proba``
            and ``pred_is_dns_data_exfiltration`` added.
        """
        console.print("[info]Making predictions...[/]")

        # Process features
        processed_df = self.feature_processor.prepare_features(df)
        X = self._feature_tensor(processed_df)

        # Make predictions
        self.model.eval()
        with torch.no_grad():
            scores = self.model(X).squeeze(-1).cpu().numpy()

        # Scatter the scores back to the rows they belong to; the processor
        # keeps exactly the rank == 1 rows, in their original order.
        scored_rows = (df['rank'] == 1).to_numpy()
        proba = np.full(len(df), np.nan, dtype=np.float32)
        proba[scored_rows] = scores
        flagged = np.zeros(len(df), dtype=bool)
        flagged[scored_rows] = scores >= THRESHOLD

        # Add predictions to dataframe
        df['pred_is_dns_data_exfiltration_proba'] = proba
        df['pred_is_dns_data_exfiltration'] = flagged

        console.print("[success]Predictions complete![/]")
        return df

    def save(self, path: str):
        """Save the network's state dict to ``path``."""
        torch.save(self.model.state_dict(), path)
        console.print(f"[success]Model saved to {path}[/]")

    def load(self, path: str):
        """Load the network's state dict from ``path``."""
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted location.
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        console.print(f"[success]Model loaded from {path}[/]")

# File name of the persisted checkpoint inside MODEL_DIRECTORY.
_MODEL_FILENAME = 'original_dns_model.pt'


def _model_path() -> str:
    """Return the checkpoint path shared by init() and fit()."""
    return os.path.join(MODEL_DIRECTORY, _MODEL_FILENAME)


def init(df: pd.DataFrame, param: Dict[str, Any]) -> DNSExfiltrationModel:
    """Initialize the model, restoring a saved checkpoint when one exists.

    Args:
        df: Unused here; part of the entry-point signature.
        param: Unused here; part of the entry-point signature.

    Returns:
        A new ``DNSExfiltrationModel``, with weights loaded from
        ``MODEL_DIRECTORY`` if a checkpoint file is present there.
    """
    model = DNSExfiltrationModel()
    checkpoint = _model_path()
    # Load from the very path that was checked, not from the working directory.
    if os.path.exists(checkpoint):
        model.load(checkpoint)
    return model


def fit(model: DNSExfiltrationModel, df: pd.DataFrame,
        param: Dict[str, Any]) -> Dict[str, str]:
    """Train the model on new data and persist it where init() looks for it.

    Args:
        model: Model to train.
        df: Training data passed to ``model.train``.
        param: Unused here; part of the entry-point signature.

    Returns:
        A dict with a status ``message`` and the stringified loss ``history``.
    """
    history = model.train(df)
    os.makedirs(MODEL_DIRECTORY, exist_ok=True)
    model.save(_model_path())
    return {"message": "Model trained successfully", "history": str(history)}

def apply(model: DNSExfiltrationModel, df: pd.DataFrame,
          param: Dict[str, Any]) -> pd.DataFrame:
    """Score ``df`` with ``model`` and return the frame ``model.predict`` produces.

    ``param`` is accepted as part of the entry-point signature and is not used.
    """
    scored = model.predict(df)
    return scored